#!/usr/bin/env python3

import os
import re
import socket
import subprocess
import sys
import threading
import traceback
import urllib.request
from http.server import BaseHTTPRequestHandler, HTTPServer
from socketserver import ThreadingMixIn


def is_ipv6(host):
    """Return False when *host* parses as an IPv4 address, True otherwise.

    The check is purely "not IPv4": anything ``socket.inet_aton`` rejects
    (IPv6 literals, but also host names or non-string values) is reported
    as IPv6.
    """
    try:
        socket.inet_aton(host)
    except (OSError, TypeError, ValueError):
        # inet_aton raises OSError for strings that are not IPv4 addresses;
        # TypeError/ValueError cover non-string or otherwise malformed input.
        # Catching only these lets KeyboardInterrupt/SystemExit propagate.
        return True
    return False


def get_local_port(host, ipv6):
    """Ask the OS for a currently free port on *host* and return its number.

    A temporary socket (AF_INET6 when *ipv6* is true, AF_INET otherwise) is
    bound to port 0 so the kernel picks the port. The socket is closed when
    the ``with`` block exits, so the port is only known to have been free at
    the time of the call.
    """
    address_family = socket.AF_INET6 if ipv6 else socket.AF_INET
    with socket.socket(address_family) as probe:
        probe.bind((host, 0))
        bound_address = probe.getsockname()
        # Index 1 is the port for both IPv4 and IPv6 address tuples.
        return bound_address[1]


# Host and HTTP port of the ClickHouse server that queries are sent to.
# Both can be overridden through the environment; the port is kept as a string.
CLICKHOUSE_HOST = os.environ.get("CLICKHOUSE_HOST", "localhost")
CLICKHOUSE_PORT_HTTP = os.environ.get("CLICKHOUSE_PORT_HTTP", "8123")

# Server returns this JSON response.
# NOTE: the indentation inside the literal is a tab character; it has to stay
# in sync with the "\t" sequences in EXPECTED_ANSWER below.
SERVER_JSON_RESPONSE = """{
	"login": "ClickHouse",
	"id": 54801242,
	"name": "ClickHouse",
	"company": null
}"""

# Payload length in characters; used to compute how many GET requests a
# download split by max_download_buffer_size is expected to need.
PAYLOAD_LEN = len(SERVER_JSON_RESPONSE)

# What the ClickHouse answer is compared against: the same payload with its
# newlines and tabs written as literal backslash sequences (\n, \t).
EXPECTED_ANSWER = """{\\n\\t"login": "ClickHouse",\\n\\t"id": 54801242,\\n\\t"name": "ClickHouse",\\n\\t"company": null\\n}"""

#####################################################################################
# This test starts an HTTP server and serves data to a ClickHouse url-engine based table.
# The objective of this test is to check that the ClickHouse server uses HTTP Range
# requests when the HTTP server supports them, does not use them otherwise, and
# retries range requests that receive an unexpected response.
# In order for it to work, the ip+port of the HTTP server (given below) should be
# accessible from the ClickHouse server.
#####################################################################################

# IP address of this host accessible from the outside world: the first
# address printed by `hostname -i`.
HTTP_SERVER_HOST = (
    subprocess.check_output(["hostname", "-i"]).decode("utf-8").split()[0]
)
IS_IPV6 = is_ipv6(HTTP_SERVER_HOST)
HTTP_SERVER_PORT = get_local_port(HTTP_SERVER_HOST, IS_IPV6)

# (host, port) pair the HTTP server started from this script binds to.
HTTP_SERVER_ADDRESS = (HTTP_SERVER_HOST, HTTP_SERVER_PORT)

# URL handed to ClickHouse; an IPv6 host has to be wrapped in square brackets.
HTTP_SERVER_URL_STR = (
    f"http://[{HTTP_SERVER_HOST}]:{HTTP_SERVER_PORT}/"
    if IS_IPV6
    else f"http://{HTTP_SERVER_HOST}:{HTTP_SERVER_PORT}/"
)


def get_ch_answer(query):
    """Send *query* to ClickHouse over HTTP and return the response body as text.

    The endpoint is taken from the CLICKHOUSE_URL environment variable when
    set; otherwise it is built from CLICKHOUSE_HOST and CLICKHOUSE_PORT_HTTP.
    Errors raised by ``urllib.request.urlopen`` propagate to the caller.
    """
    host = CLICKHOUSE_HOST
    # A bare IPv6 literal must be bracketed inside a URL, otherwise its colons
    # are indistinguishable from the port separator. The decision is based on
    # the ClickHouse host itself, not on the address family of the local test
    # HTTP server.
    if ":" in host and not host.startswith("["):
        host = f"[{host}]"

    url = os.environ.get("CLICKHOUSE_URL", f"http://{host}:{CLICKHOUSE_PORT_HTTP}")
    # Close the connection deterministically instead of relying on GC.
    with urllib.request.urlopen(url, data=query.encode()) as response:
        return response.read().decode()


def check_answers(query, answer):
    """Run *query* against ClickHouse and compare the result with *answer*.

    Both sides are compared after stripping surrounding whitespace. On a
    mismatch the query, the expected answer and the fetched answer are
    written to stderr and an Exception is raised.
    """
    fetched = get_ch_answer(query)
    if fetched.strip() == answer.strip():
        return

    report = (
        ("FAIL on query:", query),
        ("Expected answer:", answer),
        ("Fetched answer :", fetched),
    )
    for label, value in report:
        print(label, value, file=sys.stderr)
    raise Exception("Fail on query")


# Matches "bytes=<first>-" and "bytes=<first>-<last>"; the last position is optional.
BYTE_RANGE_RE = re.compile(r"bytes=(\d+)-(\d+)?$")


def parse_byte_range(byte_range):
    """Parse a Range header value such as 'bytes=123-456'.

    Returns a ``(first, last)`` tuple of ints. ``last`` is None for an
    open-ended range ('bytes=123-'); both are None when *byte_range* is empty
    or only whitespace.

    Raises ValueError when the value does not match the expected syntax or
    when ``last`` is smaller than ``first``.
    """
    if byte_range.strip() == "":
        return None, None

    m = BYTE_RANGE_RE.match(byte_range)
    if not m:
        raise ValueError(f"Invalid byte range {byte_range}")

    first_text, last_text = m.groups()
    first = int(first_text)
    last = None if last_text is None else int(last_text)
    # Compare against None explicitly: a last position of 0 is a real value
    # and must still be rejected when it precedes the first position.
    if last is not None and last < first:
        raise ValueError(f"Invalid byte range {byte_range}")
    return first, last


# Request handler that serves SERVER_JSON_RESPONSE and, when enabled,
# honours HTTP Range requests and records how it was called.
class HttpProcessor(BaseHTTPRequestHandler):
    """Serve the fixed JSON payload for GET/HEAD, optionally by byte range.

    All configuration and bookkeeping lives in class attributes because the
    server creates a new handler instance for every request.
    """

    # When True the handler parses the Range header, advertises
    # "Accept-Ranges: bytes" and answers 206 instead of 200.
    allow_range = False
    # Set to True once a GET carrying a parsed Range has been served.
    range_used = False
    # Number of GET requests for which response headers were sent successfully.
    get_call_num = 0
    # Status codes (as strings) forced onto GET requests whose range does not
    # start at 0; consumed from the end of the list, one per request.
    responses_to_get = []
    # Handlers run concurrently on separate threads, so updates of the shared
    # counters above are serialized through this lock.
    _state_lock = threading.Lock()

    def send_head(self, from_get=False):
        """Send the status line and headers; return the payload or None.

        Stores the parsed range (or None) in ``self.range``. Returns None
        after sending a 400 for an unparsable Range header or a 416 for a
        range starting beyond the payload; otherwise returns the complete
        encoded payload.
        """
        range_header = self.headers["Range"]
        if range_header and HttpProcessor.allow_range:
            try:
                self.range = parse_byte_range(range_header)
            except ValueError:
                self.send_error(400, "Invalid byte range")
                return None
        else:
            self.range = None

        first, last = self.range if self.range else (None, None)
        if first is None:
            first = 0

        payload = SERVER_JSON_RESPONSE.encode()
        payload_len = len(payload)
        if first and first >= payload_len:
            self.send_error(416, "Requested Range Not Satisfiable")
            return None

        # Only GET requests for a range that does not start at 0 consume a
        # forced status code. Check and pop under the lock so two concurrent
        # requests cannot both take the last entry.
        forced_code = None
        if first != 0 and from_get is True:
            with HttpProcessor._state_lock:
                if HttpProcessor.responses_to_get:
                    forced_code = HttpProcessor.responses_to_get.pop()

        if forced_code is not None:
            # Always emit a status line; the codes are stored as strings.
            self.send_response(int(forced_code))
        else:
            self.send_response(206 if HttpProcessor.allow_range else 200)

        self.send_header("Content-type", "application/json")

        if HttpProcessor.allow_range:
            self.send_header("Accept-Ranges", "bytes")

        # Clamp an open-ended or oversized range to the last payload byte.
        if last is None or last >= payload_len:
            last = payload_len - 1

        response_length = last - first + 1

        if first or last:
            self.send_header("Content-Range", f"bytes {first}-{last}/{payload_len}")
        self.send_header(
            "Content-Length",
            str(response_length) if HttpProcessor.allow_range else str(payload_len),
        )
        self.end_headers()
        return payload

    def do_HEAD(self):
        """Answer HEAD with the same headers as GET and no body."""
        self.send_head()

    def do_GET(self):
        """Send headers, then the whole payload or the requested slice of it."""
        payload = self.send_head(True)
        if payload is None:
            return

        with HttpProcessor._state_lock:
            HttpProcessor.get_call_num += 1

        if not self.range:
            self.wfile.write(payload)
            return

        HttpProcessor.range_used = True
        start, stop = self.range
        if stop is None:
            stop = len(payload) - 1
        if start is None:
            start = 0
        # Range positions are inclusive, hence the +1 on the slice end.
        self.wfile.write(payload[start : stop + 1])

    def log_message(self, format, *args):
        """Suppress the default per-request logging."""
        return


class HTTPServerV6(HTTPServer):
    """HTTPServer variant whose listening socket uses the IPv6 address family."""

    address_family = socket.AF_INET6


class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
    """HTTPServer that handles every request in its own thread."""


class ThreadedHTTPServerV6(ThreadingMixIn, HTTPServerV6):
    """IPv6 HTTP server that handles every request in its own thread."""


def start_server():
    """Start the test HTTP server on a background thread.

    The IPv6 or IPv4 threaded server class is chosen according to IS_IPV6
    and bound to HTTP_SERVER_ADDRESS with HttpProcessor as handler.

    Returns the ``(thread, server)`` pair so the caller can shut the server
    down and join the thread.
    """
    server_class = ThreadedHTTPServerV6 if IS_IPV6 else ThreadedHTTPServer
    server = server_class(HTTP_SERVER_ADDRESS, HttpProcessor)

    serving_thread = threading.Thread(target=server.serve_forever)
    serving_thread.start()
    return serving_thread, server


#####################################################################
# Testing area.
#####################################################################


def test_select(settings):
    """Select the payload through url() with *settings* and verify the answer.

    *settings* is a mapping of ClickHouse setting names to values; each value
    is rendered with ``repr`` into the SETTINGS clause of the query.
    """
    settings_clause = ",".join(
        [name + "=" + repr(value) for name, value in settings.items()]
    )
    query = f"SELECT * FROM url('{HTTP_SERVER_URL_STR}','JSONAsString') SETTINGS {settings_clause};"
    check_answers(query, EXPECTED_ANSWER)


def run_test(allow_range, settings, check_retries=False):
    """Run one SELECT scenario and validate how the HTTP server was used.

    Args:
        allow_range: whether the HTTP server should support Range requests.
        settings: ClickHouse settings for the query; must contain
            "max_download_buffer_size".
        check_retries: when True, the server is primed with forced status
            codes so that range requests have to be retried.

    Raises:
        Exception: if the answer differs from the expected one, if Range
            usage does not match *allow_range*, if forced responses were left
            unconsumed, or if the number of GET requests is unexpected.
    """
    # Reset every piece of shared handler state, including the forced
    # responses, so that leftovers from a previous scenario cannot leak in.
    HttpProcessor.range_used = False
    HttpProcessor.get_call_num = 0
    HttpProcessor.allow_range = allow_range
    HttpProcessor.responses_to_get = ["500", "200", "206"] if check_retries else []
    retries_num = len(HttpProcessor.responses_to_get)

    test_select(settings)

    download_buffer_size = settings["max_download_buffer_size"]
    # Ceiling division: number of buffer-sized chunks needed for the payload.
    expected_get_call_num = (PAYLOAD_LEN - 1) // download_buffer_size + 1
    if allow_range:
        if not HttpProcessor.range_used:
            raise Exception("HTTP Range was not used when supported")

        if check_retries and len(HttpProcessor.responses_to_get) > 0:
            raise Exception(
                "Expected to get http response 500, which had to be retried, but 200 ok returned and then retried"
            )

        if retries_num > 0:
            expected_get_call_num += retries_num - 1

        if expected_get_call_num != HttpProcessor.get_call_num:
            raise Exception(
                f"Invalid amount of GET calls with Range. Expected {expected_get_call_num}, actual {HttpProcessor.get_call_num}"
            )
    else:
        if HttpProcessor.range_used:
            raise Exception("HTTP Range used while not supported")

    print("PASSED")


def main():
    """Start the HTTP server, run all scenarios, and always stop the server.

    Any exception raised by a scenario propagates to the caller after the
    server has been shut down and its thread joined.
    """
    t, httpd = start_server()
    try:
        settings = {"max_download_buffer_size": 20}

        # Range requests not supported by the server.
        run_test(allow_range=False, settings=settings)
        # Range requests supported, 20-byte download buffer.
        run_test(allow_range=True, settings=settings)

        # Range requests supported, 10-byte download buffer.
        settings = {"max_download_buffer_size": 10}
        run_test(allow_range=True, settings=settings)

        # Range requests supported, max_download_threads=2, with forced
        # responses: 500, then 200, then 206 on the retry.
        # NOTE(review): the original comment said parallel download is not
        # used in this scenario - confirm against the ClickHouse reader.
        settings["max_download_threads"] = 2
        run_test(allow_range=True, settings=settings, check_retries=True)
    finally:
        # Stop serving even when a scenario fails, so the serving thread
        # does not outlive main().
        httpd.shutdown()
        t.join()


if __name__ == "__main__":
    # os._exit is used on both paths so the process terminates immediately
    # with the chosen status, without waiting for any remaining threads.
    try:
        main()
        sys.stdout.flush()
        os._exit(0)
    except Exception as ex:
        # Print the traceback and the message of the failure, then exit 1.
        traceback.print_tb(ex.__traceback__, file=sys.stderr)
        print(ex, file=sys.stderr)
        sys.stderr.flush()

        os._exit(1)
